import cv2
import math
import numpy as np
import torch

# Single-image haze removal using the dark channel prior (DCP), adapted from:
# https://github.com/He-Zhang/image_dehaze/blob/master/dehaze.py

def DarkChannel(im, sz):
    """Return the dark channel of a 3-channel image.

    The dark channel is the per-pixel minimum over the three colour channels,
    followed by a square ``sz`` x ``sz`` minimum filter (grayscale erosion with
    a rectangular structuring element).

    Args:
        im: image with exactly three channels (unpacked as b, g, r).
        sz: side length of the square erosion window.

    Returns:
        Single-channel array with the same height/width as ``im``.
    """
    window = cv2.getStructuringElement(cv2.MORPH_RECT, (sz, sz))
    blue, green, red = cv2.split(im)
    channel_min = cv2.min(cv2.min(red, green), blue)
    return cv2.erode(channel_min, window)


def AtmLight(im, dark):
    """Estimate the global atmospheric light A.

    Selects the pixels with the largest dark-channel values (0.1% of the
    image, but at least one pixel) and averages their colour in ``im``.

    Args:
        im: (h, w, 3) image; reshaped to (h*w, 3).
        dark: (h, w) dark channel of ``im`` (e.g. from DarkChannel).

    Returns:
        float64 array of shape (1, 3) with the per-channel atmospheric light.
    """
    h, w = im.shape[:2]
    imsz = h * w
    numpx = int(max(math.floor(imsz / 1000), 1))
    darkvec = dark.reshape(imsz)
    imvec = im.reshape(imsz, 3)

    # Flat indices of the numpx pixels with the largest dark-channel values.
    indices = darkvec.argsort()[imsz - numpx:]

    # Average over EVERY selected pixel, including indices[0], so A is a true
    # mean and is non-zero even when numpx == 1 (a zero A would make
    # TransmissionEstimate divide by zero).
    A = imvec[indices].sum(axis=0, dtype=np.float64).reshape(1, 3) / numpx
    return A


def TransmissionEstimate(im, A, sz, omega=0.95):
    """Coarse transmission map: ``1 - omega * DarkChannel(im / A, sz)``.

    Args:
        im: 3-channel image, channels last.
        A: atmospheric light indexed as ``A[0, c]`` (e.g. the (1, 3) result
            of AtmLight).
        sz: side length of the square window passed to DarkChannel.
        omega: scale applied to the normalised dark channel. Defaults to
            0.95, the value this function always used.

    Returns:
        Transmission map with one value per pixel.
    """
    im3 = np.empty(im.shape, im.dtype)

    # Normalise each colour channel by its own atmospheric-light component;
    # the slice assignment casts the quotient back to im.dtype.
    for ind in range(3):
        im3[:, :, ind] = im[:, :, ind] / A[0, ind]

    return 1 - omega * DarkChannel(im3, sz)


def Guidedfilter(im, p, r, eps):
    """Edge-preserving guided filter of ``p`` using ``im`` as the guide.

    Fits a local linear model ``q = a * im + b`` in every window and averages
    the coefficients. Window means use ``cv2.boxFilter`` (normalised) with a
    kernel size of ``(r, r)``, so ``r`` is the full window side length here,
    not a radius.

    Args:
        im: single-channel guide image.
        p: single-channel image to filter, same size as ``im``.
        r: box-filter kernel side length.
        eps: regulariser added to the local guide variance; larger values
            give smoother output.

    Returns:
        Filtered float64 image.
    """
    def box_mean(x):
        return cv2.boxFilter(x, cv2.CV_64F, (r, r))

    guide_mean = box_mean(im)
    src_mean = box_mean(p)
    cross_cov = box_mean(im * p) - guide_mean * src_mean
    guide_var = box_mean(im * im) - guide_mean * guide_mean

    slope = cross_cov / (guide_var + eps)
    intercept = src_mean - slope * guide_mean

    return box_mean(slope) * im + box_mean(intercept)


def TransmissionRefine(im, et, r=60, eps=0.0001):
    """Refine a coarse transmission map with a guided filter.

    The guide is the grayscale version of ``im`` scaled to [0, 1].

    Args:
        im: BGR image. Integer images (e.g. uint8 from ``cv2.imread``) are
            rescaled by their dtype's maximum value. Float images are assumed
            to already be in [0, 1], which is what the callers in this file
            pass.
        et: coarse transmission map (e.g. from TransmissionEstimate).
        r: guided-filter box kernel size (defaults to the previous
            hard-coded 60).
        eps: guided-filter regulariser (defaults to the previous hard-coded
            1e-4).

    Returns:
        Refined float64 transmission map.
    """
    gray = np.float64(cv2.cvtColor(im, cv2.COLOR_BGR2GRAY))

    # eps is only meaningful relative to a [0, 1] guide. Dividing an
    # already-normalised float image by 255 again would shrink the guide's
    # variance far below eps and turn the guided filter into a plain blur.
    if np.issubdtype(im.dtype, np.integer):
        gray = gray / np.iinfo(im.dtype).max

    return Guidedfilter(gray, et, r, eps)


def Recover(im, t, A, tx=0.1):
    """Invert the haze model per channel: ``J = (I - A) / max(t, tx) + A``.

    Args:
        im: hazy image, channels last; channels 0-2 are recovered.
        t: transmission map with one value per pixel.
        A: atmospheric light indexed as ``A[0, c]``.
        tx: lower bound applied to ``t`` before dividing.

    Returns:
        Array with the same shape and dtype as ``im``; values are not clipped.
    """
    bounded_t = cv2.max(t, tx)
    out = np.empty(im.shape, im.dtype)

    for channel in range(3):
        airlight = A[0, channel]
        out[:, :, channel] = (im[:, :, channel] - airlight) / bounded_t + airlight

    return out


def dcp_dehazing(I):
    """Dehaze one image with the dark channel prior pipeline.

    Steps: dark channel -> atmospheric light -> coarse transmission ->
    guided-filter refinement -> radiance recovery (transmission floor 0.1).

    Args:
        I: 3-channel image; callers in this file pass float32 values in
            [0, 1].

    Returns:
        Recovered image with the same shape and dtype as ``I`` (not clipped).
    """
    patch_size = 15
    dark_channel = DarkChannel(I, patch_size)
    airlight = AtmLight(I, dark_channel)
    coarse_t = TransmissionEstimate(I, airlight, patch_size)
    refined_t = TransmissionRefine(I, coarse_t)
    return Recover(I, refined_t, airlight, 0.1)


def batch_dcp_dehazing(batch_I, device='cuda'):
    """Apply DCP dehazing to every image in a batch.

    Args:
        batch_I: tensor indexed as (N, C, H, W). Each item is moved to the
            host, transposed to (H, W, C), and mapped with ``(x + 1) / 2``
            before dehazing (this assumes values in [-1, 1]; TODO confirm
            against callers).
        device: where the result is placed. Defaults to ``'cuda'``, matching
            the previously hard-coded ``.cuda()``. Pass ``'cpu'`` (or
            ``batch_I.device``) on machines without a GPU.

    Returns:
        Tensor of torch's default dtype, shaped like ``batch_I``, with each
        dehazed image mapped back via ``x * 2 - 1``, on ``device``.
    """
    batch_J = torch.zeros(batch_I.size())
    for idx in range(batch_I.size(0)):
        # Detach first so the host copy is not recorded in the autograd graph.
        cur_I = batch_I[idx].detach().cpu().numpy().transpose(1, 2, 0)
        cur_I = (cur_I + 1) / 2
        cur_J = dcp_dehazing(cur_I)

        # Indexed assignment copies into batch_J, so no clone() is needed.
        batch_J[idx] = torch.from_numpy((cur_J.transpose(2, 0, 1) * 2) - 1)

    return batch_J.to(device)


if __name__ == '__main__':
    fn = "../../dataset/RESIDE/ITS/hazy/1_10_0.98796.png"
    # IMREAD_COLOR always yields 3-channel 8-bit BGR. IMREAD_UNCHANGED (-1)
    # could return an alpha channel, which breaks DarkChannel's b, g, r
    # unpacking, or 16-bit data, which the /255 below would mis-scale.
    src = cv2.imread(fn, cv2.IMREAD_COLOR)
    if src is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise SystemExit("could not read image: {}".format(fn))
    print(src.shape)
    # src = cv2.resize(src, (256, 256))
    I = src.astype('float32') / 255
    J = dcp_dehazing(I)
    cv2.imshow("0", J)
    cv2.waitKey(0)

    # x = torch.zeros(2, 3, 256, 256)
    # batch_dcp_dehazing(x)

